Add W4A8/W4A_FP8 MoE support with groupwise scale #202
Open
ClementLinCF wants to merge 8 commits into main from
Conversation
Pull request overview
This PR adds W4A8 (int4) and W4A_FP8 (int4_fp8) MoE support with groupwise scaling (group_size=32) to the FlyDSL fused MoE 2-stage kernel. It extends the existing int4_bf16 (W4A16) path with new load/unpack helpers and per-K32 group accumulation logic.
Changes:
- New `int4_fp8` dtype support with FP8 activations + packed int4 weights, using `mfma_f32_16x16x32_fp8_fp8` and in-kernel int4→fp8 conversion via `cvt_pk_fp8_f32`.
- Groupwise scale (group_size=32) for all int4 weight variants (int4, int4_bf16, int4_fp8) with a per-K32 fresh-accumulator + scale + running f32 accumulator pattern.
- New load/unpack helpers (`load_b_raw_w4a8_k64`, `load_b_raw_w4a8_groupwise_k64`, `unpack_b_w4a8`, `unpack_b_w4a_fp8`, etc.) in `mfma_preshuffle_pipeline.py`.
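To make the groupwise-scale scheme concrete, here is a minimal reference sketch of what group_size=32 scaling means for the int4 weight variants: each group of 32 rows along K shares one float scale. The function names and the symmetric max-abs quantization rule are illustrative assumptions, not the PR's actual quantizer.

```python
import numpy as np

def quantize_w4_groupwise(w, group_size=32):
    """Hypothetical reference: per-group symmetric int4 quantization.
    w: float weight matrix [K, N]; each block of `group_size` rows
    along K shares one scale, matching group_size=32 in this PR."""
    K, N = w.shape
    g = w.reshape(K // group_size, group_size, N)
    # int4 symmetric range is [-8, 7]; use max-abs per group for the scale.
    scale = np.abs(g).max(axis=1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.round(g / scale), -8, 7).astype(np.int8)
    return q.reshape(K, N), scale.reshape(-1, N)

def dequantize_w4_groupwise(q, scale, group_size=32):
    """Inverse of the sketch above: broadcast each group's scale."""
    K, N = q.shape
    g = q.reshape(-1, group_size, N).astype(np.float32)
    return (g * scale.reshape(-1, 1, N)).reshape(K, N)
```

Round-tripping through this sketch bounds the per-element error by half the group's scale, which is the usual argument for why groupwise scales outperform a single per-tensor scale.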
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| kernels/mfma_preshuffle_pipeline.py | New load/unpack helpers for W4A8, W4A_FP8, and groupwise scale variants |
| kernels/moe_gemm_2stage.py | Extended stage1/stage2 compile functions with int4_fp8 dtype and groupwise scale paths |
| tests/kernels/test_moe_gemm.py | Added int4_fp8 to test parameterization and corresponding quantization/routing logic |
| tests/test_common.py | Minor whitespace cleanup |
coderfeli reviewed Mar 17, 2026
…tions
- Extract `_unpack_int4_to_int8_pair()`: shared 7-op int4->int8 bit manipulation used by `unpack_b_w4a16`, `unpack_b_w4a8`, `unpack_b_w4a_fp8`, and `load_b_pack_k32` (was copy-pasted in 4 places)
- Extract `_pack_i32_pair_to_i64()`: shared (even, odd) -> i64 packing
- Extract `_load_groupwise_scale()`: shared scale address calculation and buffer_load for W4A16 and W4A8 groupwise paths
- Have `load_b_raw_w4a8_groupwise_k64` delegate weight load to `load_b_raw_w4a8_k64` (matching W4A16 groupwise pattern)
- Replace `ir.IntegerType.get_signless(32)` / `ir.F32Type.get()` with `T.i32` / `T.f32` to follow project conventions
- Replace `arith.constant(..., index=True)` with `fx.Index(...)` throughout
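As a scalar model of the shared int4->int8 unpack step that the refactor centralizes, the sketch below sign-extends two packed 4-bit weights per byte. The function name and the nibble layout (low nibble = even element, high nibble = odd element) are illustrative assumptions, not the kernel's actual API.

```python
import numpy as np

def unpack_int4_pairs(packed):
    """Hypothetical model of an int4->int8 unpack: each input byte
    holds two signed 4-bit values; returns (even, odd) int8 arrays.
    Layout assumption: low nibble = even element, high = odd."""
    packed = np.asarray(packed, dtype=np.uint8)

    def sext4(x):
        # Sign-extend a 4-bit value held in the low bits of a byte:
        # values >= 8 represent negatives in two's complement.
        x = x.astype(np.int16)
        return np.where(x >= 8, x - 16, x).astype(np.int8)

    return sext4(packed & 0x0F), sext4(packed >> 4)
```

The in-kernel version does the equivalent with a handful of shift/and/or ops on packed i32 registers rather than per-byte arithmetic.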
- Add 'bf16' to out_dtype parametrization (was only f16/f32)
- Fix run_moe_stage2 to accept bf16 output dtype
- Fix bytes_moved calculation to treat bf16 as 2-byte (like f16)

The stage2 kernel (compile_moe_gemm2) already supports out_dtype='bf16' using bf16 global atomics on gfx94+/gfx95+, but the test harness blocked it. Verified all 24 new test cases pass on MI355X (gfx950).
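The bytes_moved fix amounts to counting bf16 as a 2-byte element type, like f16, so achieved-bandwidth numbers are not overstated. A minimal sketch, with a hypothetical function name and element-count argument:

```python
def stage2_out_bytes(num_elems, out_dtype):
    """Hypothetical sketch of the accounting the bytes_moved fix
    addresses: bf16 and f16 are both 2 bytes per element; f32 is 4.
    Treating bf16 as 4 bytes would double the reported traffic."""
    elem_bytes = {"f16": 2, "bf16": 2, "f32": 4}[out_dtype]
    return num_elems * elem_bytes
```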
Motivation
The existing fused MoE 2-stage kernel supports fp8, fp16, bf16, int8, and W4A16 (int4_bf16) data types. This PR extends it with W4A8 (int4) and W4A_FP8 (int4_fp8) support, and adds groupwise scale (group_size=32) for all three int4 weight variants — enabling lower-precision MoE inference paths that are critical for production deployment of large MoE models (e.g., Kimi K2.5).
Technical Details
New dtype: int4_fp8 (W4A_FP8)
- FP8 activations + packed int4 weights, using mfma_f32_16x16x32_fp8_fp8.
- In-kernel int4→fp8 unpack via the cvt_pk_fp8_f32 ROCDL intrinsic.
- 8-byte K64 weight loads (buffer_load_dwordx2) for improved memory efficiency.
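One reason the in-kernel int4→fp8 conversion is safe: every signed int4 value in [-8, 7] is exactly representable in an fp8 e4m3 format, so the unpack loses no precision. The sketch below models rounding to a 3-bit-mantissa grid (ignoring subnormals and saturation) to illustrate this; it is a simplified assumption, not the hardware's conversion path.

```python
import math

def round_to_e4m3_grid(x):
    """Round x to the nearest value on a 3-bit-mantissa (e4m3-style)
    grid: within a binade [2^e, 2^(e+1)) the spacing is 2^(e-3).
    Simplified model for illustration (no subnormals/saturation)."""
    if x == 0:
        return 0.0
    m = abs(x)
    e = math.floor(math.log2(m))
    step = 2.0 ** (e - 3)
    return math.copysign(round(m / step) * step, x)
```

Every integer magnitude 1..8 lands exactly on this grid, so the int4→fp8 round trip is lossless; a value like 2.3 would instead round to the nearest 0.25 step.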
Groupwise scale (group_size=32) for W4A8/W4A16/W4A_FP8
- Per-K32 groupwise accumulation: fresh MFMA accumulator + per-group scale + running f32 accumulator, for both stage1 and stage2.
- Groupwise scale address formula using an [E, num_groups, N] layout with preshuffled scale tensors.
- Epilogue correctly skips sitofp for groupwise accumulators (already f32 from per-K32 accumulation).
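The per-K32 pattern can be sketched as a reference GEMM: each K group gets a fresh accumulator (modeling the MFMA tile), is scaled once by that group's factor, and is folded into a running f32 accumulator. The function name is illustrative, and `scales` here is the per-expert [num_groups, N] slice of the [E, num_groups, N] layout.

```python
import numpy as np

def moe_gemm_groupwise_ref(a, wq, scales, group_size=32):
    """Hypothetical reference for per-K32 groupwise accumulation.
    a: [M, K] activations; wq: [K, N] int4-valued weights (unpacked);
    scales: [K // group_size, N] per-group, per-column scales."""
    M, K = a.shape
    acc = np.zeros((M, wq.shape[1]), dtype=np.float32)
    for g in range(K // group_size):
        k0 = g * group_size
        # Fresh accumulator for this K32 slice (models the MFMA tile),
        # then apply the group's scale once and fold into the running acc.
        part = (a[:, k0:k0 + group_size].astype(np.float32)
                @ wq[k0:k0 + group_size].astype(np.float32))
        acc += part * scales[g][None, :]
    return acc
```

Because the running accumulator is already f32, no final sitofp is needed, which matches the epilogue behavior noted above.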
bf16 output dtype test coverage
- Add out_dtype="bf16" to test parametrization (was only f16/f32).
- Fix the run_moe_stage2 test helper to accept bf16 output dtype.
- Fix the bytes_moved bandwidth calculation to treat bf16 as 2-byte.

Test Plan
pytest tests/kernels/test_moe_gemm.py::test_moe_gemm_2stage covering:
Each test verifies correctness against torch reference (for S/M shapes) and runs perf timing.
Test Results
MI308 (gfx942) — Original tests
120 passed, 360 skipped, 288 deselected in 35.39s
MI355X (gfx950) — Full test suite
264 passed, 504 skipped, 0 failed in 3m43s
All applicable tests pass on gfx950. Skips are expected:
- mask tests (valid_mask not supported on gfx950)
- graph mode tests (HIP graph capture skipped)
- out_f32 reduce tests (accumulate=False forbids it)

MI355X (gfx950) — bf16 output tests (new)
24/24 passed — all input dtypes × S/M/L sizes with out_dtype=bf16.

Peak performance on MI355X (L size):
E2E test
Kimi-K2.5 W4A16, W4A8 on MI308
Build & Test Environment (MI355X)
Built and tested inside a Docker container (lmsysorg/sglang-rocm:v0.5.9-rocm700-mi35x-20260322):
- Built at commit 7f77ca0dbda4 with 128 threads
- Installed via pip install -e . → flydsl 0.1.1.dev413

Note: The PyPI wheel
flydsl==0.1.1 is incompatible with this branch due to commit 4d84ee8 (native idx2crd APIs). Building from source is required.

Submission Checklist